{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eB-dvsMI09O4"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mwvC53CC1K3n"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0WlR2-Y3v1mf"
      },
      "source": [
        "End-to-end example: download, export, and run a pre-trained Flax model\n",
        "===============\n",
        "1. The model is ResNet50, downloaded from the Hugging Face pre-trained models.\n",
        "2. The test image is a photo of a cat.\n",
        "3. The `orbax-export` API is used to export the JAX module to a TF SavedModel, together with the image pre/post-processing functions.\n",
        "4. The TFLite Converter and Runtime are used to convert the TF SavedModel to `.tflite` and run it on the test image.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zQXWQ7y11eIR"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/jax_to_tflite_resnet50\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gq1373VwVlKV"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TUOym6hJnTxC"
      },
      "source": [
        "```\n",
        "!pip install orbax-export\n",
        "\n",
        "!pip install tf-nightly\n",
        "\n",
        "!pip install --upgrade jax jaxlib\n",
        "\n",
        "!pip install transformers flax\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hDzBKqEfrafW"
      },
      "source": [
        "## Show Test Image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2IOV4p_jripo"
      },
      "outputs": [],
      "source": [
        "from PIL import Image\n",
        "import jax\n",
        "import requests\n",
        "\n",
        "url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg\"\n",
        "image = Image.open(requests.get(url, stream=True).raw)\n",
        "image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPcMyI8ERmIB"
      },
      "source": [
        "## Download and test the pre-trained ResNet50 Flax model from Hugging Face"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a__LrY62PUtg"
      },
      "outputs": [],
      "source": [
        "from transformers import ConvNextImageProcessor, FlaxResNetForImageClassification\n",
        "\n",
        "image_processor = ConvNextImageProcessor.from_pretrained(\"microsoft/resnet-50\")\n",
        "model = FlaxResNetForImageClassification.from_pretrained(\"microsoft/resnet-50\")\n",
        "\n",
        "inputs = image_processor(images=image, return_tensors=\"np\")\n",
        "outputs = model(**inputs)\n",
        "logits = outputs.logits\n",
        "\n",
        "# model predicts one of the 1000 ImageNet classes\n",
        "predicted_class_idx = jax.numpy.argmax(logits, axis=-1)\n",
        "print(\"Predicted class:\", model.config.id2label[predicted_class_idx.item()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MKJHRRjGR1rC"
      },
      "source": [
        "# Create a JAX wrapper for the model\n",
        "A wrapper is needed so that the model's inputs match what TFLite accepts: a single tensor or a tuple of tensors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Af2zv1qFSOR3"
      },
      "outputs": [],
      "source": [
        "import flax.linen as nn\n",
        "from transformers import FlaxResNetForImageClassification\n",
        "\n",
        "\n",
        "class Resnet50Wrapper(nn.Module):\n",
        "  pretrained_model_name: str = \"microsoft/resnet-50\"  # Pre-trained model name\n",
        "\n",
        "  def setup(self):\n",
        "    # Initialize the pre-trained ResNet50 model\n",
        "    self.model = FlaxResNetForImageClassification.from_pretrained(\n",
        "        self.pretrained_model_name\n",
        "    )\n",
        "\n",
        "  def __call__(self, inputs):\n",
        "    # Process input images through the ResNet50 model\n",
        "    outputs = self.model(pixel_values=inputs)\n",
        "\n",
        "    # Return logits or directly apply softmax for probabilities (optional)\n",
        "    return outputs.logits"
      ]
    },
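    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The wrapper returns raw logits. As a minimal sketch of the optional softmax step mentioned in the comment above, the pure-Python helper below converts logits into probabilities. The function name `softmax` and the sample logits are illustrative assumptions, not part of the original notebook; inside the wrapper itself, `jax.nn.softmax(outputs.logits, axis=-1)` would be the more natural choice.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def softmax(logits):\n",
        "  # Subtract the largest logit first so that exp() cannot overflow.\n",
        "  shift = max(logits)\n",
        "  exps = [math.exp(value - shift) for value in logits]\n",
        "  total = sum(exps)\n",
        "  return [e / total for e in exps]\n",
        "\n",
        "\n",
        "# Hypothetical logits for three classes (not real model output).\n",
        "probs = softmax([2.0, 1.0, 0.1])\n",
        "best = max(range(len(probs)), key=probs.__getitem__)\n",
        "print(\"Most likely class index:\", best)\n",
        "```\n",
        "\n",
        "Softmax preserves the ordering of the logits, so the arg-max class is the same either way, which is why returning logits is sufficient for classification."
      ]
    },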
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mrKtRfkhf73M"
      },
      "source": [
        "## TensorFlow image pre-processor function\n",
        "This essentially reimplements the underlying logic of the `ConvNextImageProcessor` class in Hugging Face Transformers:\n",
        "\u003e ```\n",
        "\u003e image_processor = ConvNextImageProcessor.from_pretrained(\"microsoft/resnet-50\")\n",
        "\u003e inputs = image_processor(images=image, return_tensors=\"np\")\n",
        "\u003e ```\n",
        "\n",
        "This function is reused later, during the orbax-export step, as the `tf_preprocessor`.\n",
        "\n",
        "Note: The output of `ConvNextImageProcessor` could be fed to the TFLite model directly. This example uses a TF function instead to show how `orbax-export` can bundle input/output pre/post-processing into the exported model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h85-Qb1PgANA"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def resnet_image_processor(image_tensor):\n",
        "  # 1. Resize and Cast to Float32\n",
        "  image_resized = tf.image.resize(\n",
        "      image_tensor, (224, 224), method=tf.image.ResizeMethod.BILINEAR\n",
        "  )\n",
        "  image_float = tf.cast(image_resized, tf.float32)\n",
        "\n",
        "  # 2. Normalize (Using TensorFlow Constants)\n",
        "  mean = tf.constant([0.485, 0.456, 0.406])\n",
        "  std = tf.constant([0.229, 0.224, 0.225])\n",
        "  image_normalized = (image_float / 255.0 - mean) / std\n",
        "\n",
        "  # 3. Transpose for Channel-First Format\n",
        "  image_transposed = tf.transpose(image_normalized, perm=[2, 0, 1])\n",
        "\n",
        "  # 4. Add Batch Dimension\n",
        "  return tf.expand_dims(image_transposed, axis=0)"
      ]
    },
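    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The normalization and layout steps of `resnet_image_processor` can be checked by hand on a tiny input. The sketch below repeats them in pure Python on nested lists; `normalize_pixel`, `hwc_to_chw`, and the two-pixel image are hypothetical names and data introduced only for illustration, and the resize step is left out.\n",
        "\n",
        "```python\n",
        "# Channel statistics copied from resnet_image_processor above.\n",
        "MEAN = (0.485, 0.456, 0.406)\n",
        "STD = (0.229, 0.224, 0.225)\n",
        "\n",
        "\n",
        "def normalize_pixel(rgb):\n",
        "  # Rescale each 0-255 channel to [0, 1], then standardize it.\n",
        "  return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))\n",
        "\n",
        "\n",
        "def hwc_to_chw(image):\n",
        "  # image[h][w][c] becomes out[c][h][w], like tf.transpose(perm=[2, 0, 1]).\n",
        "  height, width, channels = len(image), len(image[0]), len(image[0][0])\n",
        "  return [\n",
        "      [[image[h][w][c] for w in range(width)] for h in range(height)]\n",
        "      for c in range(channels)\n",
        "  ]\n",
        "\n",
        "\n",
        "# A hypothetical 1x2 image: one black pixel and one white pixel.\n",
        "tiny = [[normalize_pixel((0, 0, 0)), normalize_pixel((255, 255, 255))]]\n",
        "chw = hwc_to_chw(tiny)\n",
        "print(len(chw), len(chw[0]), len(chw[0][0]))\n",
        "```\n",
        "\n",
        "The printed sizes are the channel, height, and width of the transposed image. The real function then adds a leading batch dimension to reach the `[1, 3, 224, 224]` shape used by the later export cells."
      ]
    },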
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3GD4Wb2K3VT_"
      },
      "source": [
        "## Initialize the JAX model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2wHvNFpiThVL"
      },
      "outputs": [],
      "source": [
        "# Initialize the JAX Model\n",
        "jax_model = Resnet50Wrapper()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uwP7XjS43M7W"
      },
      "source": [
        "## Validate the wrapped JAX Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gl8mypra3MGw"
      },
      "outputs": [],
      "source": [
        "# Convert the raw image to a float32 RGB tensor\n",
        "raw_image_tensor = tf.convert_to_tensor(np.array(image, dtype=np.float32))\n",
        "\n",
        "# Apply the TF image preprocessing above to get an input tensor supported by ResNet50\n",
        "input_tensor = resnet_image_processor(raw_image_tensor)\n",
        "\n",
        "# Run the JAX model\n",
        "jax_logits = jax_model.apply({}, input_tensor.numpy())\n",
        "\n",
        "jax_predicted_class_idx = jax.numpy.argmax(jax_logits, axis=-1)\n",
        "print(\"Predicted class:\", model.config.id2label[jax_predicted_class_idx.item()])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lKZst83Mu6TD"
      },
      "outputs": [],
      "source": [
        "raw_image_tensor.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuAl8ubTTKxZ"
      },
      "source": [
        "# Export to TFLite model and run"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pz61IKnwjhbn"
      },
      "source": [
        "## Export the JAX model to a TF SavedModel using orbax-export"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cwq1rxEZjpoU"
      },
      "outputs": [],
      "source": [
        "from orbax.export import ExportManager, JaxModule, ServingConfig\n",
        "\n",
        "# Wrap the model params and function into a JaxModule.\n",
        "jax_module = JaxModule({}, jax_model.apply, trainable=False)\n",
        "\n",
        "# Specify the serving configuration and export the model.\n",
        "# The input signature is fixed to the shape of raw_image_tensor (height, width, channels).\n",
        "serving_config = ServingConfig(\n",
        "    \"serving_default\",\n",
        "    input_signature=[tf.TensorSpec([480, 640, 3], tf.float32, name=\"inputs\")],\n",
        "    tf_preprocessor=resnet_image_processor,\n",
        "    tf_postprocessor=lambda x: tf.argmax(x, axis=-1),\n",
        ")\n",
        "\n",
        "export_manager = ExportManager(jax_module, [serving_config])\n",
        "\n",
        "saved_model_dir = \"resnet50_saved_model\"\n",
        "export_manager.save(saved_model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O00xe8MevVLj"
      },
      "source": [
        "### Convert to a TFLite Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uf2-ZMECmtJ3"
      },
      "outputs": [],
      "source": [
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
        "tflite_model = converter.convert()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n-fBAd64vZse"
      },
      "source": [
        "### Run the TFLite model on `raw_image_tensor` with the TFLite Interpreter API"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "unFFby0Zmup-"
      },
      "outputs": [],
      "source": [
        "def run_tflite_model(tflite_model_content, input_tensor):\n",
        "  \"\"\"Runs a single-input TFLite model and returns its first output.\"\"\"\n",
        "  interpreter = tf.lite.Interpreter(model_content=tflite_model_content)\n",
        "  interpreter.allocate_tensors()\n",
        "\n",
        "  # Copy the input into the model's first input slot, then run inference.\n",
        "  input_details = interpreter.get_input_details()[0]\n",
        "  interpreter.set_tensor(input_details[\"index\"], input_tensor)\n",
        "  interpreter.invoke()\n",
        "\n",
        "  # Read back the first output tensor.\n",
        "  output_details = interpreter.get_output_details()\n",
        "  return interpreter.get_tensor(output_details[0][\"index\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z9Rwk8iFvSBO"
      },
      "outputs": [],
      "source": [
        "output_data = run_tflite_model(tflite_model, raw_image_tensor)\n",
        "print(\"Predicted class:\", model.config.id2label[output_data[0]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NbtCrNLYKpK8"
      },
      "source": [
        "## Export to TF Saved Model without pre/post-processing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qPLmCDniKvV3"
      },
      "outputs": [],
      "source": [
        "saved_model_dir_2 = \"resnet50_saved_model_1\"\n",
        "\n",
        "tf.saved_model.save(\n",
        "    jax_module,\n",
        "    saved_model_dir_2,\n",
        "    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n",
        "        tf.TensorSpec([1, 3, 224, 224], tf.float32, name=\"inputs\")\n",
        "    ),\n",
        "    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wzp8M2ebLYoH"
      },
      "outputs": [],
      "source": [
        "converter_1 = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir_2)\n",
        "tflite_model_1 = converter_1.convert()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Di5ikJYMcKy"
      },
      "outputs": [],
      "source": [
        "output_data_1 = run_tflite_model(tflite_model_1, input_tensor)\n",
        "tfl_predicted_class_idx_1 = tf.argmax(output_data_1, axis=-1).numpy()\n",
        "print(\"Predicted class:\", model.config.id2label[tfl_predicted_class_idx_1[0]])"
      ]
    },
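    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above prints only the arg-max class. As a sketch, a small helper can list the top-k classes from a logits vector such as `output_data_1[0]`. The names `top_k`, `sample_logits`, and `sample_labels`, and all values below, are hypothetical and stand in for the real model output and `model.config.id2label`.\n",
        "\n",
        "```python\n",
        "def top_k(logits, k=3):\n",
        "  # Rank class indices by descending logit and keep the first k.\n",
        "  ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)\n",
        "  return ranked[:k]\n",
        "\n",
        "\n",
        "# Hypothetical logits and labels (not real model output).\n",
        "sample_logits = [0.1, 2.5, -1.0, 1.7]\n",
        "sample_labels = {0: \"dog\", 1: \"tabby cat\", 2: \"car\", 3: \"tiger cat\"}\n",
        "for index in top_k(sample_logits, k=2):\n",
        "  print(sample_labels[index], sample_logits[index])\n",
        "```\n",
        "\n",
        "Looking at more than one class can help when comparing the export paths, since closely related labels may trade places under small numerical differences."
      ]
    },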
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_SwsbUH_Smt7"
      },
      "source": [
        "## Convert directly from a TF concrete function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2MbcN1vjSt5E"
      },
      "outputs": [],
      "source": [
        "converter_2 = tf.lite.TFLiteConverter.from_concrete_functions(\n",
        "    [\n",
        "        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n",
        "            tf.TensorSpec([1, 3, 224, 224], tf.float32, name=\"inputs\")\n",
        "        )\n",
        "    ]\n",
        ")\n",
        "tflite_model_2 = converter_2.convert()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WdM3p03AS-xq"
      },
      "outputs": [],
      "source": [
        "output_data_2 = run_tflite_model(tflite_model_2, input_tensor)\n",
        "tfl_predicted_class_idx_2 = tf.argmax(output_data_2, axis=-1).numpy()\n",
        "print(\"Predicted class:\", model.config.id2label[tfl_predicted_class_idx_2[0]])"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "jax_to_tflite_resnet50.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
